import math
import torch.nn as nn

from .unet import UNet

# The U-Net architecture was an early breakthrough for image segmentation.
#   - In the top row is the full resolution, the row below has half that, and
#     so on
#   - The data flows from top left to bottom center through a series of
#     convolutions and downscaling
#   - Then we go up again, using upscaling convolutions to get back to the full
#     resolution
#
# Earlier network designs already had this U-shape, which people attempted to
# use to address the limited receptive field size of fully convolutional
# networks. To do so, they used a design that copied,
# inverted, and appended the focusing portions of an image-classification
# network to create a symmetrical model that goes from fine detail to wide
# receptive field and back to fine detail.
#
# Those earlier network designs had problems converging, however, most likely
# due to the loss of spatial information during downsampling. Once information
# reaches a large number of very downscaled images, the exact location of
# object boundaries gets harder to encode and therefore reconstruct.
#
# To address this, the U-Net authors added the skip connections. In U-Net, skip
# connections short-circuit inputs along the downsampling path into the
# corresponding layers in the upsampling path. These layers receive as input
# both the upsampled results of the wide receptive field layers from lower in
# the U as well as the output of the earlier fine detail layers via the "copy
# and crop" bridge connections. This is the key innovation behind U-Net.
#
# All of this means those final detail layers are operating with the best of
# both worlds. They've got both information about the larger context surrounding
# the immediate area and fine detail data from the first set of full-resolution
# layers.

# Adapting an off-the-shelf model to our project
#   - Pass the input through batch normalization. This way, we won't have to
#     normalize the data ourselves in the dataset.
#   - Pass the output through an nn.Sigmoid layer to restrict the output to the
#     range [0, 1].
#   - Reduce the total depth and number of filters we allow our model to use.
#   - Our output is a single channel, with each pixel of output representing
#     the model's estimate of the probability that the pixel in question is
#     part of a nodule.


class UNetWrapper(nn.Module):
    """2D U-Net adapted for CT-slice nodule segmentation.

    Wraps the stock ``UNet`` with:
      - an input ``BatchNorm2d`` layer, so the dataset does not have to
        normalize samples itself;
      - a final ``Sigmoid``, restricting each output pixel to [0, 1] so it
        can be read as a per-pixel nodule probability;
      - custom Kaiming weight initialization.

    All keyword arguments are forwarded verbatim to ``UNet``; ``in_channels``
    is additionally used to size the input batch-norm layer.
    """

    def __init__(self, **kwargs):
        super().__init__()

        # U-Net is fundamentally a two-dimensional segmentation model. We
        # could adapt the implementation to use 3D convolutions, in order to
        # use information across slices, but the memory usage of a
        # straightforward 3D implementation would be considerably greater and
        # would force us to chop up the CT scan. Also, because pixel spacing
        # in the Z direction is much larger than in-plane, a nodule is less
        # likely to be present across many slices. These considerations make
        # a fully 3D approach less attractive for our purposes. Instead,
        # we'll segment our 3D data a slice at a time, providing adjacent
        # slices for context. Since we're presenting the data in 2D, the
        # input *channels* represent the adjacent slices.
        self.input_batchnorm = nn.BatchNorm2d(kwargs["in_channels"])
        self.unet = UNet(**kwargs)
        self.final = nn.Sigmoid()

        self._init_weights()

    def forward(self, input_batch):
        """Normalize the batch, segment it, and squash the logits to [0, 1].

        Returns a tensor of per-pixel probabilities (same spatial size as
        the U-Net output).
        """
        bn_output = self.input_batchnorm(input_batch)
        un_output = self.unet(bn_output)
        fn_output = self.final(un_output)
        return fn_output

    def _init_weights(self):
        """Kaiming-initialize every conv / linear weight, plus its bias.

        Weights use ``kaiming_normal_`` with ``mode="fan_out"`` to match the
        ReLU nonlinearity used throughout the network. Biases follow
        PyTorch's standard conv/linear scheme: a uniform draw over
        ``[-1/sqrt(fan), 1/sqrt(fan)]``.

        Bug fix: the original code called
        ``nn.init.normal_(m.bias, -bound, bound)``, which interprets
        ``-bound`` as the *mean* and ``bound`` as the *std* of a normal
        distribution — biases ended up centered at ``-bound`` instead of
        spanning ``[-bound, bound]``. ``uniform_`` is the intended call.
        """
        init_set = {
            nn.Conv2d,
            nn.Conv3d,
            nn.ConvTranspose2d,
            nn.ConvTranspose3d,
            nn.Linear,
        }
        for m in self.modules():
            if type(m) in init_set:
                nn.init.kaiming_normal_(
                    m.weight.data, mode="fan_out", nonlinearity="relu", a=0
                )
                if m.bias is not None:
                    # NOTE(review): bound uses fan_out to stay consistent
                    # with the fan_out mode chosen for the weights above
                    # (PyTorch's own default uses fan_in here).
                    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(
                        m.weight.data
                    )
                    bound = 1 / math.sqrt(fan_out)
                    nn.init.uniform_(m.bias, -bound, bound)
